package com.haha.http.rpc.client;

import com.haha.http.rpc.common.Request;
import com.haha.http.rpc.common.Response;
import com.haha.http.rpc.common.trace.TraceContextHolder;
import com.haha.http.rpc.common.util.JsonUtils;
import com.haha.http.rpc.common.util.LogUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.FactoryBean;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.http.client.SimpleClientHttpRequestFactory;
import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.http.converter.StringHttpMessageConverter;
import org.springframework.web.client.RestTemplate;

import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.nio.charset.StandardCharsets;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/**
 * Spring {@link FactoryBean} that exposes an HTTP-backed RPC client for a service interface.
 *
 * <p>{@link #getObject()} returns a JDK dynamic proxy implementing {@code objectType}. Each
 * interface method call is packed into a {@link Request} (interface type, method name, trace id,
 * arguments), POSTed as JSON to {@code http://<host>/rpc}, and the {@link Response} data is
 * converted back to the method's generic return type with {@link JsonUtils}.
 *
 * <p>The {@code equals}, {@code hashCode} and {@code toString} methods inherited from
 * {@link Object} are answered locally by the proxy and never sent to the server.
 *
 * @param <T> the service interface type proxied by this factory
 * @AUTHOR zhangxiaofan07
 * @CREATE 2022-01-27 19:27
 */
public class RpcClientRequestHandler<T> implements FactoryBean<T> {
    private static final Logger logger = LoggerFactory.getLogger("RPC_CLIENT_DIGEST");

    /** Server address, inserted into the endpoint as {@code http://<host>/rpc} and logged per call. */
    private final String host;
    /** Endpoint every call is POSTed to; computed once because {@code host} never changes. */
    private final String url;
    private final RestTemplate restTemplate;
    private Class<T> objectType;

    /**
     * Creates a client with a 1000 ms connect timeout and a 500 ms read timeout.
     *
     * @param objectType service interface to proxy
     * @param host       server address used to build {@code http://<host>/rpc}
     */
    public RpcClientRequestHandler(Class<T> objectType, String host) {
        this(objectType, host, 1000, 500);
    }

    /**
     * Creates a client with explicit HTTP timeouts.
     *
     * @param objectType     service interface to proxy
     * @param host           server address used to build {@code http://<host>/rpc}
     * @param connectTimeout connect timeout in milliseconds
     * @param readTimeout    read timeout in milliseconds
     */
    public RpcClientRequestHandler(Class<T> objectType, String host, int connectTimeout, int readTimeout) {
        this.objectType = objectType;
        this.host = host;
        this.url = "http://" + host + "/rpc";
        this.restTemplate = initRestTemplate(connectTimeout, readTimeout);
    }

    /**
     * Creates the client proxy for {@code objectType}.
     *
     * <p>The proxy is defined in the interface's own class loader, which is guaranteed to see the
     * interface; the thread context class loader may be {@code null} or a different loader.
     */
    @Override
    @SuppressWarnings("unchecked") // the proxy implements exactly objectType, so the cast is safe
    public T getObject() {
        return (T) Proxy.newProxyInstance(objectType.getClassLoader(),
                new Class<?>[]{objectType},
                (proxy, method, args) -> method.getDeclaringClass() == Object.class
                        ? invokeObjectMethod(proxy, method, args)
                        : invoke(method, args));
    }

    /**
     * Answers the {@link Object} methods that {@link Proxy} routes to the handler
     * ({@code equals}, {@code hashCode}, {@code toString}) without a remote call.
     *
     * <p>Forwarding them would make simply logging or hashing the proxy issue an HTTP request.
     * A failed call would also return {@code null} for the primitive {@code hashCode}/{@code equals},
     * which the proxy rejects with a {@link NullPointerException}.
     */
    private Object invokeObjectMethod(Object proxy, Method method, Object[] args) {
        switch (method.getName()) {
            case "equals":
                return proxy == args[0];
            case "hashCode":
                return System.identityHashCode(proxy);
            case "toString":
                return "RpcClientProxy[" + objectType.getName() + "@" + host + "]";
            default:
                // Proxy only dispatches the three methods above from Object.
                throw new UnsupportedOperationException(method.toString());
        }
    }

    /**
     * Performs one remote call for {@code method}.
     *
     * <p>Any exception is logged and rethrown. A digest line (host, request, response, cost) is
     * always logged.
     *
     * @param method interface method invoked on the proxy
     * @param args   call arguments; {@code null} for no-arg methods, as passed by {@link Proxy}
     * @return the decoded response data; {@code null} when the response body is missing, the
     *         server reports failure, or the method returns {@code void}
     */
    private Object invoke(Method method, Object[] args) {
        long startTime = System.currentTimeMillis();

        Request request = null;
        Response response = null;

        try {
            request = buildRequest(method, args);

            HttpHeaders headers = new HttpHeaders();
            headers.setContentType(MediaType.APPLICATION_JSON);
            HttpEntity<Request> requestEntity = new HttpEntity<>(request, headers);

            ResponseEntity<Response> responseEntity =
                    restTemplate.postForEntity(url, requestEntity, Response.class);

            response = responseEntity.getBody();
            // A missing body or a failure flag is surfaced to the caller as null, not as an exception.
            // NOTE(review): for primitive return types the proxy turns this null into a
            // NullPointerException; consider throwing a descriptive exception instead.
            if (response == null || !response.isSuccess()) {
                return null;
            }

            if (method.getReturnType() == void.class) {
                return null;
            }
            return JsonUtils.toObject(response.getData(), method.getGenericReturnType());
        } catch (Exception e) {
            LogUtils.error(logger, "RPC CLIENT ERROR", e);
            throw e;
        } finally {
            long cost = System.currentTimeMillis() - startTime;
            // NOTE(review): this clears the calling thread's trace context after every call, so a
            // later RPC in the same request may not see the original trace id. Confirm this is intended.
            TraceContextHolder.clear();
            logger.info("rpc client host = {}, request = {}, response = {}, cost = {}ms,",
                    host, request, response, cost);
        }
    }

    /**
     * Builds the wire request for one call.
     *
     * <p>The request carries the target interface, the method name, the current trace id and the
     * arguments keyed by their declared parameter types.
     */
    private Request buildRequest(Method method, Object[] args) {
        Request request = new Request();
        request.setObjType(objectType);
        request.setTraceId(TraceContextHolder.getTraceId());
        request.setMethodName(method.getName());

        if (args != null) {
            Class<?>[] parameterTypes = method.getParameterTypes();
            // NOTE(review): keying by parameter type means two parameters of the same type
            // (e.g. foo(String, String)) collide, and only the last value is sent. The key type
            // presumably matches what the server expects, so it is left unchanged here.
            Map<Class<?>, Object> paramMap = new LinkedHashMap<>();
            int count = Math.min(args.length, parameterTypes.length);
            for (int i = 0; i < count; i++) {
                paramMap.put(parameterTypes[i], args[i]);
            }
            request.setParamMap(paramMap);
        }
        return request;
    }

    @Override
    public Class<?> getObjectType() {
        return objectType;
    }

    /**
     * Creates the {@link RestTemplate} used for all calls, with the given timeouts.
     *
     * <p>Every {@link StringHttpMessageConverter} is switched to UTF-8.
     */
    private RestTemplate initRestTemplate(int connectTimeout, int readTimeout) {
        SimpleClientHttpRequestFactory requestFactory = new SimpleClientHttpRequestFactory();
        requestFactory.setConnectTimeout(connectTimeout);
        requestFactory.setReadTimeout(readTimeout);

        RestTemplate template = new RestTemplate(requestFactory);
        // Find the HTTP message converters and set the string converters' default charset to UTF-8.
        List<HttpMessageConverter<?>> httpMessageConverters = template.getMessageConverters();
        for (HttpMessageConverter<?> httpMessageConverter : httpMessageConverters) {
            if (httpMessageConverter instanceof StringHttpMessageConverter) {
                ((StringHttpMessageConverter) httpMessageConverter).setDefaultCharset(StandardCharsets.UTF_8);
            }
        }
        return template;
    }

    /** Replaces the proxied interface; affects proxies created by later {@link #getObject()} calls. */
    public void setObjectType(Class<T> objectType) {
        this.objectType = objectType;
    }
}
